In [31]:
%matplotlib inline
from sklearn.datasets import fetch_mldata
import matplotlib.pyplot as plt
import numpy as np
from IPython.html.widgets import interact
from IPython.display import display
import core

In [32]:
mnist = fetch_mldata("MNIST original", data_home="./MNIST dataset")

In [33]:
mnist.target.shape


Out[33]:
(70000,)

In [38]:
def show_examples(i):
    plt.matshow(mnist.data[i].reshape((28,28)), cmap='Greys_r')
    display(mnist.target[i])
interact(show_examples, i=[60000,70000-1])


6.0

In [57]:
phipps = core.gen_network([784,40,10])
phipps


Out[57]:
[array([[-1.18045863,  0.99909873,  1.48408414, ...,  0.07700514,
          0.02055552, -1.85437819],
        [ 0.82048939, -0.6677628 ,  0.09726159, ..., -0.25809884,
         -0.81745753,  0.2215101 ],
        [-1.82209173, -0.30833717,  0.70276417, ...,  1.26639302,
         -1.2415207 ,  1.69466223],
        ..., 
        [ 0.04177277,  0.87976371, -2.34645315, ..., -0.04584682,
         -0.86120687, -0.19235447],
        [ 0.37467598,  0.55056671,  0.74541514, ..., -0.18803911,
          0.4837592 , -0.55211067],
        [-0.30821086,  0.60741   , -0.7220094 , ...,  1.05986457,
         -0.75256015, -0.67976349]]),
 array([[  9.75242254e-01,   1.41007556e-02,   5.08885126e-01,
          -8.89233039e-01,   6.96220748e-01,   1.24515760e+00,
           1.82054426e+00,  -7.69981773e-01,   1.16082585e+00,
          -7.47697535e-01,  -4.78418942e-01,  -1.15748825e+00,
           2.42877335e+00,   7.82008423e-01,   1.09819442e+00,
          -1.62097191e+00,   4.47348051e-01,   1.90487661e-02,
           9.31129681e-01,  -1.58091681e+00,   2.47950177e-01,
           2.64875066e-01,   3.45531524e-01,   8.41898548e-01,
           7.25186717e-01,  -3.43024341e-01,  -8.66679494e-01,
           1.09679531e+00,   1.62208156e+00,   7.47115079e-01,
          -9.51291590e-02,   6.70098738e-01,   6.12628265e-01,
          -3.97879537e-01,  -6.58478779e-01,  -9.71701002e-01,
           2.08484079e+00,   7.34469878e-01,  -4.60973104e-01,
          -1.45835169e+00,  -1.64901872e-02],
        [  6.85883089e-02,  -8.04937495e-01,  -8.89101505e-02,
           6.94017611e-01,  -6.35280609e-01,  -3.76968710e-01,
           5.05434530e-01,   1.80726206e+00,  -1.58980649e-01,
           1.09974245e+00,  -1.32264429e+00,   1.90657656e+00,
          -9.33108409e-01,  -1.07977908e+00,   1.96674393e-01,
          -2.17243315e+00,   9.60888274e-01,   1.20506637e-01,
          -7.45032670e-02,  -1.19349150e-01,   1.62598708e-01,
           3.81075533e-01,   1.34310184e-01,  -1.00264873e+00,
          -4.00769328e-01,   1.06994380e+00,  -1.19756697e+00,
          -1.34582355e-01,   3.49737477e-01,  -1.24913732e+00,
           7.49004377e-01,   1.64415550e-03,  -4.47172894e-01,
          -1.12367000e+00,  -5.11820922e-02,   5.85872158e-01,
           1.80032809e-01,  -3.44284483e-01,   5.50596298e-01,
           8.61070705e-01,  -2.64503394e+00],
        [ -7.17755040e-01,   1.25873950e+00,  -6.15900752e-01,
           1.88176084e+00,  -1.38957513e-01,  -8.32114546e-01,
           1.09082489e+00,   3.07008242e-01,  -2.49707996e-01,
           2.87873512e-01,  -1.80971164e+00,  -1.62605595e-01,
           4.96582846e-01,   4.28194916e-01,   9.20176978e-01,
          -7.84112068e-01,   1.02888091e+00,  -2.11363380e+00,
           2.12364776e+00,  -4.37296905e-01,   9.55451706e-01,
          -1.77394898e+00,   9.21702520e-01,  -7.37541791e-01,
           4.52255398e-01,   2.89586567e-01,  -6.90722916e-01,
           8.70102920e-01,  -6.77134006e-01,   4.06319051e-01,
          -2.02996231e+00,  -6.87213336e-01,  -4.19412398e-01,
          -7.41710604e-01,  -1.44594966e+00,  -1.12372177e+00,
          -5.78252725e-01,  -1.75459687e+00,  -1.65693468e+00,
          -5.25230809e-01,  -5.26290114e-01],
        [ -9.53162131e-01,  -1.47788376e+00,  -5.36040341e-01,
           1.90820653e+00,   2.93999947e-01,  -4.02311677e-01,
           3.08598136e-01,  -9.61287383e-01,  -1.65883699e+00,
          -9.10116003e-01,  -1.07168199e+00,  -4.07681526e-01,
          -2.87004927e-01,  -7.38379411e-01,   2.76408429e+00,
          -1.17446777e+00,   3.10023319e-01,  -3.18178327e-01,
           1.07402716e+00,  -8.17580380e-01,  -1.28361483e-01,
          -1.83752294e+00,   6.60678966e-01,  -9.61288166e-01,
          -3.64238899e-01,   4.01921252e-01,   5.08135350e-01,
          -1.06711128e+00,  -1.17119249e-01,   2.27852832e+00,
           5.60911679e-01,   9.40698356e-01,  -9.07170858e-02,
          -7.76294600e-01,   1.72900748e+00,   6.02730236e-01,
           1.92762248e+00,  -4.27226816e-01,  -1.66415506e+00,
           1.55091924e-01,  -4.54765886e-01],
        [  7.20027740e-01,   1.10064750e+00,   2.19248269e-01,
           6.73060949e-01,  -1.41953696e+00,   6.79567330e-01,
           8.70526227e-01,  -1.18067024e+00,   1.68951393e+00,
          -3.82680568e-01,   2.82609687e-01,   5.91561127e-01,
           4.64532361e-01,  -6.80904187e-01,  -4.90137425e-01,
          -6.12901203e-02,   1.75664090e+00,   2.05275775e-01,
           1.05732823e+00,   1.65174147e+00,  -1.93911139e-02,
           7.85811135e-01,  -4.86033874e-01,   1.05368517e-01,
           1.04481946e+00,   3.25097521e-02,  -6.78813295e-01,
           7.12082843e-01,   1.88914165e+00,  -1.46240649e-02,
           3.07209784e-02,  -1.29253366e+00,   1.13069790e+00,
          -5.10839439e-01,   1.25137804e+00,  -5.17807563e-02,
           7.07975740e-01,  -1.67950981e+00,   4.52969626e-01,
          -7.32450809e-01,  -5.97715714e-01],
        [ -3.07763763e+00,  -2.39689385e+00,   2.23819523e+00,
          -2.46279487e-01,   5.25017005e-01,  -5.11388204e-01,
           1.62351389e-01,   1.94517355e-01,  -1.81335620e+00,
           1.04459390e+00,  -2.00430906e-01,   7.82137929e-01,
           6.34216747e-01,   5.73420815e-01,  -2.08003095e-02,
          -1.14274717e+00,  -1.63230120e+00,  -1.50928028e+00,
           1.13730949e+00,  -5.86315859e-01,   1.34773743e+00,
           2.47043737e+00,  -1.12148644e+00,  -3.01146849e-02,
          -6.39609418e-01,   1.09726277e+00,   8.04369753e-01,
           6.68984137e-01,  -5.75887384e-01,  -4.12289309e-01,
           2.47163189e+00,   1.04854420e-01,  -1.12389633e+00,
          -5.51742943e-01,   5.66546160e-01,   1.20872164e+00,
          -1.04900076e+00,  -4.01879701e-01,   7.54082067e-01,
          -2.11575769e+00,  -5.86761088e-01],
        [ -1.73181837e-01,   2.32205903e+00,  -1.47600729e-01,
           6.04833565e-01,   1.07243579e+00,  -2.49656365e-02,
          -8.95177703e-01,  -2.82877057e-01,  -9.32747954e-01,
          -4.86338696e-01,  -9.41220140e-01,   1.50764719e-01,
           3.30122407e-01,  -4.97690278e-01,   2.08684342e-01,
           9.19951980e-01,   7.70242391e-01,  -7.67422983e-01,
           1.42813798e+00,  -5.66349541e-01,  -1.12064859e+00,
          -6.47902792e-01,  -1.35052063e+00,   2.04928533e+00,
           7.05055805e-01,  -4.95161117e-01,  -4.31363714e-01,
          -1.14731386e+00,   2.22643001e-01,   1.59124038e+00,
           3.88060600e-01,  -1.10747728e-01,  -6.17296289e-01,
           2.84934966e-01,  -1.84627412e+00,   7.34661641e-01,
           2.33893340e-02,   1.38324164e+00,   1.10011516e+00,
           1.58857752e+00,  -4.93708241e-01],
        [ -1.93376335e+00,   5.04884193e-01,   1.05666633e+00,
           8.25973188e-01,  -1.92519725e+00,   6.84861819e-02,
           6.16630412e-01,   5.78388230e-01,   2.76796947e+00,
           1.21815848e+00,   5.63377432e-03,  -7.31650783e-01,
           2.22964051e+00,   1.34501817e-01,   1.92525852e-01,
           2.38343821e-02,  -7.53133231e-02,  -8.78864832e-01,
          -2.66117090e-02,   7.64527104e-01,  -7.68422175e-01,
          -7.42194566e-01,   8.29575546e-01,   1.92687853e-01,
          -1.13093905e+00,   7.79662200e-02,  -1.59124547e+00,
           1.09491924e+00,  -9.90646528e-02,   1.91288543e+00,
          -2.31966735e-01,  -5.50366494e-01,   9.29403579e-01,
           3.27988516e-01,   1.41739704e+00,  -4.04344941e-02,
           2.24447820e-01,  -1.00858039e+00,  -1.07560443e+00,
           2.13582799e-01,  -4.20494767e-03],
        [  4.29793534e-01,  -4.67509472e-01,   5.63665989e-02,
           1.23153460e+00,  -1.82049585e-01,   3.75116733e-01,
          -7.68187321e-01,  -9.99732429e-01,   3.80840822e-01,
          -3.31721457e-01,   1.68333789e-01,   5.81615286e-01,
           4.32785867e-01,  -3.18269734e-01,  -1.72406876e+00,
          -5.04925733e-01,   5.94999764e-02,   1.17032932e+00,
           1.43017119e+00,   7.39077142e-01,  -6.97929026e-01,
          -1.61578240e-01,   2.05903424e-01,  -3.53550538e-01,
           1.23904219e+00,  -1.30681204e-02,   1.34279534e+00,
           1.07004182e+00,   2.27285518e+00,  -1.18475655e+00,
          -1.38341022e+00,  -9.60055939e-02,  -8.33514450e-01,
           3.74992393e-01,  -6.71278615e-01,  -1.16839763e-01,
           2.89831294e-01,   1.57833465e+00,  -3.07592157e-01,
           7.33935168e-01,   3.36846543e-01],
        [ -1.19832585e-01,  -7.61416242e-01,   3.37576099e-01,
           5.23897538e-01,   1.52908064e+00,  -3.69910056e-01,
          -1.31772167e+00,  -4.65631741e-01,  -4.57807040e-01,
          -1.56974155e+00,  -1.60390719e+00,  -6.57672471e-02,
          -2.01546300e+00,  -5.93323070e-01,  -2.78441156e-01,
           6.06972189e-01,  -3.48189942e-01,  -1.27090146e+00,
           3.89334447e-01,  -1.31603870e+00,  -2.92827181e-01,
           3.45542829e-01,   2.57318725e-01,  -8.99442022e-01,
           6.71602330e-01,   1.37179934e-01,  -3.39433169e-01,
           2.50900689e-03,   1.31651923e+00,  -5.53613980e-03,
           1.28153470e+00,  -2.37858769e+00,   9.14503297e-01,
           3.05983160e-01,   1.82768448e-01,  -3.23782584e+00,
           8.29916176e-01,   1.22781046e+00,   6.85636267e-02,
          -1.59385193e+00,  -7.45062918e-01]])]

In [53]:
train_data_index = np.linspace(0,60000, 60000 + 1)
np.random.seed(1)
np.random.shuffle(train_data_index)
mist_data = []
mist_target = []
for n in train_data_index:
    mist_data.append(mnist.data[int(n)])
    mist_target.append(mnist.target[int(n)])

This Cell Takes at Minimum 20 Minutes to Run


In [ ]:
%%timeit -r1 -n1
core.train_network(phipps,mnist.data, mnist.target, 10, 10000, 15, 0.01)

In [64]:
core.check_net(phipps, mnist.data, mnist.target, [60000,70000-1])


9.8009800980098
980/9999

In [52]:


In [ ]: